In [1]:
import netCDF4 as nc
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import seaborn as sns
from dotenv import load_dotenv
from pandas.plotting import parallel_coordinates
import importlib
import plotly.express as px
import plotly.graph_objects as go
import os
import glob
import pandas as pd
import json
import utils.db_tools as db_tools

# Reload BEFORE the from-import: importlib.reload() does not rebind names that
# were already bound via `from ... import ...`, so reloading afterwards would
# leave get_data / compute_metrics / ... pointing at the stale module.
importlib.reload(db_tools)
from utils.db_tools import (
    get_db,
    filter_df,
    make_animation,
    get_data,
    metrics_grid,
    plot_grid,
    compute_metrics,
)
from classify import classify_trajectories
Out[1]:
<module 'utils.db_tools' from '/cluster/home/vogtva/pde-solvers-cuda/analysis/utils/db_tools.py'>
In [2]:
# Run configuration: which model / run's classification metrics to load.
model = "bruss"
run_id = "ball_big"
metrics_filename = "classification_metrics_02.csv"

load_dotenv()
data_dir = os.getenv("DATA_DIR")
output_dir = os.getenv("OUT_DIR")
# Fail early with a clear message instead of trying to read "None/bruss/...".
if output_dir is None:
    raise RuntimeError("OUT_DIR is not set; define it in the environment or in a .env file")

metrics_path = os.path.join(output_dir, model, run_id, metrics_filename)
df_class = classify_trajectories(pd.read_csv(metrics_path))

# Working copy with a string key for the original (unperturbed) parameter point,
# so rows can be grouped by it.
df = df_class.copy()
# Optional filter, currently disabled: keep only rows whose trajectory file exists.
# df = df[df["filename"].apply(os.path.exists)].reset_index(drop=True)
df["op"] = df["original_point"].astype(str)
In [4]:
# Number of trajectories assigned to each class by classify_trajectories.
df.loc[:, "category"].value_counts()
Out[4]:
category SS 3007 OSC 2128 I 1676 DSS 387 Name: count, dtype: int64
In [9]:
# Stacked histogram of each trajectory's mean deviation, split by the class
# assigned by classify_trajectories.
fig, ax = plt.subplots(figsize=(12, 8))
sns.histplot(
    data=df_class,
    x="mean_deviation",
    hue="category",
    multiple="stack",
    kde=False,
    ax=ax,
)
ax.set_xlabel("Mean Deviation")
ax.set_ylabel("Frequency")
ax.set_title("Distribution of Mean Deviation by Category")
plt.show()
In [10]:
# Interactive view of the sampled (A, B) parameter pairs, coloured by the
# trajectory class from classify_trajectories.
fig = px.scatter(
    data_frame=df_class,
    x="A", y="B",
    color="category",
    labels=dict(A="A", B="B"),
    title="Scatter plot of A vs B",
    width=800, height=800,
)
fig.show()
In [15]:
# Class counts over the working frame `df` (a copy of df_class plus the "op" key).
df.value_counts(subset="category")
Out[15]:
category SS 3957 OSC 1530 I 1294 DSS 419 Name: count, dtype: int64
In [14]:
def plot_ball_behavior(df, metric="dev", start_frame=0):
    """Overlay one metric curve per row of ``df`` in a single plotly figure.

    For every row, ``compute_metrics(row, start_frame=start_frame)`` is called
    and the component selected by ``metric`` is drawn as its own line, labelled
    with the row's index. The figure is shown; nothing is returned.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows accepted by ``compute_metrics`` (e.g. all runs sharing one
        original parameter point).
    metric : {"dev", "dt", "dx"}
        Which ``compute_metrics`` component to plot: index 0 ("Deviation"),
        1 ("Time Derivative") or 2 ("Spatial Derivative").
    start_frame : int, default 0
        Forwarded to ``compute_metrics``.

    Raises
    ------
    ValueError
        If ``metric`` is not one of "dev", "dt", "dx".
    """
    metric_choices = {
        "dev": (0, "Deviation"),
        "dt": (1, "Time Derivative"),
        "dx": (2, "Spatial Derivative"),
    }
    if metric not in metric_choices:
        raise ValueError(
            f"Unknown metric {metric!r}; expected one of {sorted(metric_choices)}"
        )
    metric_index, title = metric_choices[metric]

    fig = go.Figure()
    for row_id, row in df.iterrows():
        values = compute_metrics(row, start_frame=start_frame)[metric_index]
        # Same 0..100 x-range as before, but with one x per sample so the axis
        # always matches the metric's length.
        t = np.linspace(0, 100, len(values))
        fig.add_trace(
            go.Scatter(
                x=t,
                y=values,
                mode="lines",
                name=f"Row {row_id}",
                hovertemplate="Index: %{x}<br>Value: %{y:.2f}<extra></extra>",
            )
        )

    fig.update_layout(
        title=f"{title} Metrics",
        xaxis_title="Time Step/Index",
        yaxis_title=f"{title} Value",
        hovermode="x unified",
        showlegend=True,
        template="plotly_white",
    )
    fig.show()
In [15]:
# For each original parameter point: print how its perturbed runs were
# classified, then plot their deviation curves.
df_class["op"] = df_class["original_point"].astype(str)
for _, ball in df_class.groupby("op"):
    center = ball["original_point"].iloc[0]
    class_counts = ball.value_counts("category").to_dict()
    print(center, class_counts)
    plot_ball_behavior(ball)
{'A': 0.5, 'B': 0.625, 'Du': 1, 'Dv': 11} {'SS': 60}
{'A': 0.5, 'B': 0.625, 'Du': 1, 'Dv': 18} {'SS': 60}
{'A': 0.5, 'B': 0.625, 'Du': 1, 'Dv': 4} {'SS': 60}
{'A': 0.5, 'B': 0.625, 'Du': 3, 'Dv': 12} {'SS': 60}
{'A': 0.5, 'B': 0.625, 'Du': 3, 'Dv': 33} {'SS': 60}
{'A': 0.5, 'B': 0.625, 'Du': 3, 'Dv': 54} {'SS': 60}
{'A': 0.5, 'B': 1.0, 'Du': 1, 'Dv': 11} {'SS': 59, 'I': 1}
{'A': 0.5, 'B': 1.0, 'Du': 1, 'Dv': 18} {'SS': 59, 'I': 1}
{'A': 0.5, 'B': 1.0, 'Du': 1, 'Dv': 4} {'SS': 60}
{'A': 0.5, 'B': 1.0, 'Du': 3, 'Dv': 12} {'SS': 60}
{'A': 0.5, 'B': 1.0, 'Du': 3, 'Dv': 33} {'SS': 60}
{'A': 0.5, 'B': 1.0, 'Du': 3, 'Dv': 54} {'SS': 60}
{'A': 0.5, 'B': 1.5, 'Du': 1, 'Dv': 11} {'OSC': 60}
{'A': 0.5, 'B': 1.5, 'Du': 1, 'Dv': 18} {'OSC': 49, 'I': 11}
{'A': 0.5, 'B': 1.5, 'Du': 1, 'Dv': 4} {'OSC': 60}
{'A': 0.5, 'B': 1.5, 'Du': 3, 'Dv': 12} {'OSC': 60}
{'A': 0.5, 'B': 1.5, 'Du': 3, 'Dv': 33} {'OSC': 60}
{'A': 0.5, 'B': 1.5, 'Du': 3, 'Dv': 54} {'OSC': 55, 'I': 5}
{'A': 0.5, 'B': 2.0, 'Du': 1, 'Dv': 11} {'OSC': 60}
{'A': 0.5, 'B': 2.0, 'Du': 1, 'Dv': 18} {'OSC': 51, 'I': 9}
{'A': 0.5, 'B': 2.0, 'Du': 1, 'Dv': 4} {'OSC': 60}
{'A': 0.5, 'B': 2.0, 'Du': 3, 'Dv': 12} {'OSC': 60}
{'A': 0.5, 'B': 2.0, 'Du': 3, 'Dv': 33} {'OSC': 60}
{'A': 0.5, 'B': 2.0, 'Du': 3, 'Dv': 54} {'OSC': 59, 'I': 1}
{'A': 1.0, 'B': 1.25, 'Du': 1, 'Dv': 11} {'SS': 60}
{'A': 1.0, 'B': 1.25, 'Du': 1, 'Dv': 18} {'SS': 57, 'DSS': 3}
{'A': 1.0, 'B': 1.25, 'Du': 1, 'Dv': 4} {'SS': 60}
{'A': 1.0, 'B': 1.25, 'Du': 3, 'Dv': 12} {'SS': 60}
{'A': 1.0, 'B': 1.25, 'Du': 3, 'Dv': 33} {'SS': 60}
{'A': 1.0, 'B': 1.25, 'Du': 3, 'Dv': 54} {'SS': 60}
{'A': 1.0, 'B': 2.0, 'Du': 1, 'Dv': 11} {'I': 30, 'OSC': 13, 'SS': 12, 'DSS': 5}
{'A': 1.0, 'B': 2.0, 'Du': 1, 'Dv': 18} {'I': 30, 'OSC': 21, 'SS': 8, 'DSS': 1}
{'A': 1.0, 'B': 2.0, 'Du': 1, 'Dv': 4} {'OSC': 30, 'SS': 25, 'I': 5}
{'A': 1.0, 'B': 2.0, 'Du': 3, 'Dv': 12} {'OSC': 25, 'SS': 25, 'I': 10}
{'A': 1.0, 'B': 2.0, 'Du': 3, 'Dv': 33} {'I': 33, 'OSC': 16, 'SS': 11}
{'A': 1.0, 'B': 2.0, 'Du': 3, 'Dv': 54} {'I': 33, 'OSC': 14, 'SS': 13}
{'A': 1.0, 'B': 3.0, 'Du': 1, 'Dv': 11} {'OSC': 32, 'I': 27, 'DSS': 1}
{'A': 1.0, 'B': 3.0, 'Du': 1, 'Dv': 18} {'OSC': 30, 'I': 23, 'DSS': 6}
{'A': 1.0, 'B': 3.0, 'Du': 1, 'Dv': 4} {'OSC': 51, 'I': 8, 'DSS': 1}
{'A': 1.0, 'B': 3.0, 'Du': 3, 'Dv': 12} {'OSC': 56, 'I': 4}
{'A': 1.0, 'B': 3.0, 'Du': 3, 'Dv': 33} {'OSC': 43, 'DSS': 9, 'I': 8}
{'A': 1.0, 'B': 3.0, 'Du': 3, 'Dv': 54} {'I': 30, 'OSC': 30}
{'A': 1.0, 'B': 4.0, 'Du': 1, 'Dv': 11} {'OSC': 30, 'I': 21, 'DSS': 9}
{'A': 1.0, 'B': 4.0, 'Du': 1, 'Dv': 18} {'OSC': 30, 'DSS': 19, 'I': 11}
{'A': 1.0, 'B': 4.0, 'Du': 1, 'Dv': 4} {'OSC': 46, 'I': 14}
{'A': 1.0, 'B': 4.0, 'Du': 3, 'Dv': 12} {'OSC': 57, 'I': 3}
{'A': 1.0, 'B': 4.0, 'Du': 3, 'Dv': 33} {'OSC': 34, 'I': 26}
{'A': 1.0, 'B': 4.0, 'Du': 3, 'Dv': 54} {'OSC': 30, 'I': 29, 'DSS': 1}
{'A': 1.5, 'B': 1.875, 'Du': 1, 'Dv': 11} {'SS': 56, 'I': 4}
{'A': 1.5, 'B': 1.875, 'Du': 1, 'Dv': 18} {'SS': 40, 'I': 16, 'DSS': 4}
{'A': 1.5, 'B': 1.875, 'Du': 1, 'Dv': 4} {'SS': 60}
{'A': 1.5, 'B': 1.875, 'Du': 3, 'Dv': 12} {'SS': 60}
{'A': 1.5, 'B': 1.875, 'Du': 3, 'Dv': 33} {'SS': 60}
{'A': 1.5, 'B': 1.875, 'Du': 3, 'Dv': 54} {'SS': 47, 'I': 13}
{'A': 1.5, 'B': 3.0, 'Du': 1, 'Dv': 11} {'I': 29, 'SS': 25, 'OSC': 4, 'DSS': 2}
{'A': 1.5, 'B': 3.0, 'Du': 1, 'Dv': 18} {'I': 30, 'SS': 30}
{'A': 1.5, 'B': 3.0, 'Du': 1, 'Dv': 4} {'SS': 38, 'OSC': 13, 'I': 7, 'DSS': 2}
{'A': 1.5, 'B': 3.0, 'Du': 3, 'Dv': 12} {'SS': 36, 'I': 15, 'OSC': 9}
{'A': 1.5, 'B': 3.0, 'Du': 3, 'Dv': 33} {'I': 34, 'SS': 20, 'OSC': 6}
{'A': 1.5, 'B': 3.0, 'Du': 3, 'Dv': 54} {'I': 30, 'SS': 22, 'OSC': 8}
{'A': 1.5, 'B': 4.5, 'Du': 1, 'Dv': 11} {'OSC': 30, 'I': 29, 'DSS': 1}
{'A': 1.5, 'B': 4.5, 'Du': 1, 'Dv': 18} {'OSC': 30, 'I': 26, 'DSS': 4}
{'A': 1.5, 'B': 4.5, 'Du': 1, 'Dv': 4} {'OSC': 45, 'I': 15}
{'A': 1.5, 'B': 4.5, 'Du': 3, 'Dv': 12} {'OSC': 45, 'I': 15}
{'A': 1.5, 'B': 4.5, 'Du': 3, 'Dv': 33} {'I': 30, 'OSC': 30}
{'A': 1.5, 'B': 4.5, 'Du': 3, 'Dv': 54} {'I': 30, 'OSC': 30}
{'A': 1.5, 'B': 6.0, 'Du': 1, 'Dv': 11} {'OSC': 30, 'DSS': 20, 'I': 10}
{'A': 1.5, 'B': 6.0, 'Du': 1, 'Dv': 18} {'OSC': 30, 'DSS': 24, 'I': 6}
{'A': 1.5, 'B': 6.0, 'Du': 1, 'Dv': 4} {'OSC': 43, 'I': 11, 'DSS': 6}
{'A': 1.5, 'B': 6.0, 'Du': 3, 'Dv': 12} {'OSC': 45, 'I': 15}
{'A': 1.5, 'B': 6.0, 'Du': 3, 'Dv': 33} {'I': 30, 'OSC': 30}
{'A': 1.5, 'B': 6.0, 'Du': 3, 'Dv': 54} {'I': 30, 'OSC': 30}
{'A': 2.0, 'B': 2.5, 'Du': 1, 'Dv': 11} {'SS': 45, 'I': 13, 'DSS': 2}
{'A': 2.0, 'B': 2.5, 'Du': 1, 'Dv': 18} {'SS': 30, 'I': 24, 'DSS': 6}
{'A': 2.0, 'B': 2.5, 'Du': 1, 'Dv': 4} {'SS': 60}
{'A': 2.0, 'B': 2.5, 'Du': 3, 'Dv': 12} {'SS': 60}
{'A': 2.0, 'B': 2.5, 'Du': 3, 'Dv': 33} {'SS': 50, 'I': 10}
{'A': 2.0, 'B': 2.5, 'Du': 3, 'Dv': 54} {'SS': 30, 'I': 28, 'DSS': 2}
{'A': 2.0, 'B': 4.0, 'Du': 1, 'Dv': 11} {'I': 30, 'SS': 30}
{'A': 2.0, 'B': 4.0, 'Du': 1, 'Dv': 18} {'I': 30, 'SS': 28, 'OSC': 2}
{'A': 2.0, 'B': 4.0, 'Du': 1, 'Dv': 4} {'SS': 39, 'I': 20, 'DSS': 1}
{'A': 2.0, 'B': 4.0, 'Du': 3, 'Dv': 12} {'SS': 42, 'I': 17, 'DSS': 1}
{'A': 2.0, 'B': 4.0, 'Du': 3, 'Dv': 33} {'I': 31, 'SS': 29}
{'A': 2.0, 'B': 4.0, 'Du': 3, 'Dv': 54} {'I': 30, 'SS': 30}
{'A': 2.0, 'B': 6.0, 'Du': 1, 'Dv': 11} {'OSC': 30, 'I': 17, 'DSS': 13}
{'A': 2.0, 'B': 6.0, 'Du': 1, 'Dv': 18} {'OSC': 30, 'DSS': 19, 'I': 11}
{'A': 2.0, 'B': 6.0, 'Du': 1, 'Dv': 4} {'OSC': 33, 'I': 22, 'DSS': 3, 'SS': 2}
{'A': 2.0, 'B': 6.0, 'Du': 3, 'Dv': 12} {'OSC': 36, 'I': 24}
{'A': 2.0, 'B': 6.0, 'Du': 3, 'Dv': 33} {'I': 30, 'OSC': 30}
{'A': 2.0, 'B': 6.0, 'Du': 3, 'Dv': 54} {'I': 30, 'OSC': 28, 'SS': 2}
{'A': 2.0, 'B': 8.0, 'Du': 1, 'Dv': 11} {'OSC': 30, 'DSS': 20, 'I': 10}
{'A': 2.0, 'B': 8.0, 'Du': 1, 'Dv': 18} {'OSC': 30, 'DSS': 28, 'I': 2}
{'A': 2.0, 'B': 8.0, 'Du': 1, 'Dv': 4} {'OSC': 32, 'DSS': 17, 'I': 10}
{'A': 2.0, 'B': 8.0, 'Du': 3, 'Dv': 12} {'OSC': 45, 'I': 15}
{'A': 2.0, 'B': 8.0, 'Du': 3, 'Dv': 33} {'I': 30, 'OSC': 30}
{'A': 2.0, 'B': 8.0, 'Du': 3, 'Dv': 54} {'I': 30, 'OSC': 30}
{'A': 5.0, 'B': 10.0, 'Du': 1, 'Dv': 11} {'SS': 30, 'I': 29, 'DSS': 1}
{'A': 5.0, 'B': 10.0, 'Du': 1, 'Dv': 18} {'SS': 30, 'I': 25, 'DSS': 5}
{'A': 5.0, 'B': 10.0, 'Du': 1, 'Dv': 4} {'SS': 60}
{'A': 5.0, 'B': 10.0, 'Du': 3, 'Dv': 12} {'SS': 59, 'I': 1}
{'A': 5.0, 'B': 10.0, 'Du': 3, 'Dv': 33} {'I': 30, 'SS': 30}
{'A': 5.0, 'B': 10.0, 'Du': 3, 'Dv': 54} {'I': 30, 'SS': 30}
{'A': 5.0, 'B': 15.0, 'Du': 1, 'Dv': 11} {'SS': 30, 'DSS': 18, 'I': 12}
{'A': 5.0, 'B': 15.0, 'Du': 1, 'Dv': 18} {'SS': 30, 'DSS': 23, 'I': 7}
{'A': 5.0, 'B': 15.0, 'Du': 1, 'Dv': 4} {'SS': 33, 'I': 17, 'DSS': 10}
{'A': 5.0, 'B': 15.0, 'Du': 3, 'Dv': 12} {'SS': 30, 'I': 29, 'DSS': 1}
{'A': 5.0, 'B': 15.0, 'Du': 3, 'Dv': 33} {'SS': 30, 'I': 26, 'DSS': 4}
{'A': 5.0, 'B': 15.0, 'Du': 3, 'Dv': 54} {'SS': 30, 'I': 24, 'DSS': 6}
{'A': 5.0, 'B': 20.0, 'Du': 1, 'Dv': 11} {'DSS': 29, 'SS': 28, 'OSC': 2, 'I': 1}
{'A': 5.0, 'B': 20.0, 'Du': 1, 'Dv': 18} {'SS': 30, 'I': 19, 'DSS': 11}
{'A': 5.0, 'B': 20.0, 'Du': 1, 'Dv': 4} {'SS': 30, 'DSS': 25, 'I': 5}
{'A': 5.0, 'B': 20.0, 'Du': 3, 'Dv': 12} {'SS': 30, 'I': 26, 'DSS': 4}
{'A': 5.0, 'B': 20.0, 'Du': 3, 'Dv': 33} {'SS': 30, 'DSS': 18, 'I': 12}
{'A': 5.0, 'B': 20.0, 'Du': 3, 'Dv': 54} {'I': 30, 'SS': 30}
{'A': 5.0, 'B': 6.25, 'Du': 1, 'Dv': 11} {'SS': 46, 'I': 14}
{'A': 5.0, 'B': 6.25, 'Du': 1, 'Dv': 18} {'I': 30, 'SS': 30}
{'A': 5.0, 'B': 6.25, 'Du': 1, 'Dv': 4} {'SS': 60}
{'A': 5.0, 'B': 6.25, 'Du': 3, 'Dv': 12} {'SS': 60}
{'A': 5.0, 'B': 6.25, 'Du': 3, 'Dv': 33} {'SS': 41, 'I': 17, 'DSS': 2}
{'A': 5.0, 'B': 6.25, 'Du': 3, 'Dv': 54} {'I': 30, 'SS': 30}
In [ ]: